import copy

from models.stargan import Generator
from models.stargan import Discriminator
from torch.autograd import Variable
from torchvision.utils import save_image
import torch
import torch.nn.functional as F
import numpy as np
import os
import time
import datetime
import torchvision
import pickle as pkl

from collections import OrderedDict
import wandb
from utils.inb_utils import prepare_data, prepare_data_domains, sg_translate
from utils.metrics import part_wd, evaluate_fid_score


# Make float64 the default floating dtype for every tensor (and module
# parameter) created without an explicit dtype from here on, including the
# torch.zeros buffers allocated by the helpers in this module.
torch.set_default_dtype(torch.float64)


def denorm(x):
    """Map values from [-1, 1] onto [0, 1], clamping anything outside.

    A new tensor is produced; the argument ``x`` is left untouched (the
    in-place clamp only acts on the freshly computed intermediate).
    """
    shifted = x + 1
    return (shifted / 2).clamp_(0, 1)

def label2onehot(labels, dim):
    """Build a ``(len(labels), dim)`` one-hot matrix from integer-valued labels.

    ``labels`` may be any numeric tensor; entries are truncated with
    ``.long()`` before being used as column indices.
    """
    n_rows = labels.size(0)
    onehot = torch.zeros(n_rows, dim)
    row_idx = torch.arange(n_rows)
    onehot[row_idx, labels.long()] = 1
    return onehot

def eval_fid_wd_init(x, d, domain_list, mat=True, fid=True, wd=True):
    """Baseline scores: compare the real images of every pair of domains.

    No translation is performed; for each ordered pair ``(src, ref)`` taken
    from ``domain_list`` the real images of ``src`` are compared with the real
    images of ``ref``.

    Args:
        x: image tensor in [-1, 1], indexed along dim 0 like ``d``.
        d: domain label for each row of ``x``.
        domain_list: domain labels to evaluate. Matrix rows/columns follow the
            order of this list (position, not label value).
        mat: if True also return the per-pair matrices.
        fid: compute FID via ``evaluate_fid_score`` (otherwise reported as 0).
        wd: compute the Wasserstein distance via ``part_wd`` (otherwise 0).

    Returns:
        ``(avg_wd, wd_mat, avg_fid, fid_mat)`` if ``mat`` else
        ``(avg_wd, avg_fid)``. Entry ``[i, j]`` holds the score for source
        ``domain_list[i]`` against reference ``domain_list[j]``.
    """
    n_domains = len(domain_list)
    wd_mat = torch.zeros(n_domains, n_domains)
    fid_mat = torch.zeros(n_domains, n_domains)

    # Denormalise each domain's images once instead of once per pair.
    per_domain = [denorm(x[d == dom]) for dom in domain_list]

    for i, xt in enumerate(per_domain):
        for j, xr in enumerate(per_domain):
            assert torch.max(xr) <= 1 and torch.max(xt) <= 1 and torch.min(xr) >= 0 \
                and torch.min(xt) >= 0, 'Check range of output'
            # Index by position so domain_list need not be 0..n-1.
            if wd:
                wd_mat[i, j] = part_wd(xr, xt)
            if fid:
                # evaluate_fid_score receives NHWC arrays of 28x28 single-channel images.
                fid_mat[i, j] = evaluate_fid_score(
                    xr.detach().cpu().numpy().reshape(xr.shape[0], 28, 28, 1),
                    xt.detach().cpu().numpy().reshape(xt.shape[0], 28, 28, 1),
                )

    avg_wd = torch.mean(wd_mat).item() if wd else 0
    avg_fid = torch.mean(fid_mat).item() if fid else 0
    if mat:
        return avg_wd, wd_mat, avg_fid, fid_mat
    return avg_wd, avg_fid

### Wasserstein distance and FID score computed for each model ###

def eval_fid_wd(x, d, sg, domain_list, device, mat=True, fid=True, wd=True, c_dim=5):
    """Score generator translations between every ordered pair of domains.

    For each pair ``(src, tgt)`` from ``domain_list``, the images of domain
    ``src`` are translated to ``tgt`` with ``sg_translate(sg, ...)`` and
    compared with the real images of ``tgt``.

    Args:
        x: image tensor in [-1, 1], indexed along dim 0 like ``d``.
        d: domain label for each row of ``x``.
        sg: generator, passed through to ``sg_translate``.
        domain_list: domain labels to evaluate. Matrix rows/columns follow the
            order of this list (position, not label value).
        device: device the generator inputs are moved to.
        mat: if True also return the per-pair matrices.
        fid: compute FID via ``evaluate_fid_score`` (otherwise reported as 0).
        wd: compute the Wasserstein distance via ``part_wd`` (otherwise 0).
        c_dim: length of the one-hot domain codes given to the generator
            (default 5, the previously hard-coded value).

    Returns:
        ``(avg_wd, wd_mat, avg_fid, fid_mat)`` if ``mat`` else
        ``(avg_wd, avg_fid)``. Entry ``[i, j]`` holds the score for source
        ``domain_list[i]`` translated to ``domain_list[j]``.
    """
    n_domains = len(domain_list)
    wd_mat = torch.zeros(n_domains, n_domains)
    fid_mat = torch.zeros(n_domains, n_domains)

    # Real reference images per domain, denormalised once rather than per pair.
    references = [denorm(x[d == dom]) for dom in domain_list]

    for i, src in enumerate(domain_list):
        src_mask = d == src
        # Source images and codes do not depend on the target: move them once.
        xc = x[src_mask].to(device)
        src_codes = label2onehot(d[src_mask], c_dim).to(device)
        n_src = xc.size(0)
        for j, tgt in enumerate(domain_list):
            xr = references[j]
            tgt_codes = label2onehot(torch.ones(n_src) * tgt, c_dim).to(device)
            # Evaluation only: no autograd graph is needed for the translation.
            with torch.no_grad():
                xt = sg_translate(sg, xc, src_codes, tgt_codes)
            xt = denorm(xt)
            assert torch.max(xr) <= 1 and torch.max(xt) <= 1 and torch.min(xr) >= 0 \
                and torch.min(xt) >= 0, 'Check range of output'
            if wd:
                wd_mat[i, j] = part_wd(xr.cpu(), xt.cpu())
            if fid:
                # evaluate_fid_score receives NHWC arrays of 28x28 single-channel images.
                fid_mat[i, j] = evaluate_fid_score(
                    xr.detach().cpu().numpy().reshape(xr.shape[0], 28, 28, 1),
                    xt.detach().cpu().numpy().reshape(xt.shape[0], 28, 28, 1),
                )

    avg_wd = torch.mean(wd_mat).item() if wd else 0
    avg_fid = torch.mean(fid_mat).item() if fid else 0
    if mat:
        return avg_wd, wd_mat, avg_fid, fid_mat
    return avg_wd, avg_fid


class Benchmark(object):
    """Evaluate saved StarGAN generator checkpoints from federated training.

    ``run`` walks over the generator checkpoints in ``model_save_dir``, loads
    each one into ``self.G`` and records the Wasserstein distance and/or FID
    between translated and real test images (see ``fid_wd``). The client /
    model-averaging helpers mirror the federated training solver; ``run`` does
    not call them.
    """

    def __init__(self, loader_dict, domain_idx, test_loader, config):
        """Store the configuration and build G and D.

        Args:
            loader_dict: per-domain data loaders keyed by the entries of
                ``config.source_domains``.
            domain_idx: stored as ``self.domain_idx``; not read elsewhere in
                this class.
            test_loader: loader of ``(images, labels, domains)`` batches used
                by ``run`` for evaluation.
            config: namespace providing every setting read below.
        """
        # Data loader.
        self.loader_dict = loader_dict
        self.test_loader = test_loader

        self.domain_idx = domain_idx

        # fid and wd
        self.get_fid = config.get_fid
        self.get_wd = config.get_wd

        # Model configurations.
        self.c_dim = config.c_dim
        self.image_size = config.image_size
        self.g_conv_dim = config.g_conv_dim
        self.d_conv_dim = config.d_conv_dim
        self.g_repeat_num = config.g_repeat_num
        self.d_repeat_num = config.d_repeat_num
        self.lambda_cls = config.lambda_cls
        self.lambda_rec = config.lambda_rec
        self.lambda_gp = config.lambda_gp

        # Training configurations.
        self.source_domains = config.source_domains
        self.target_domain = config.target_domain
        self.dataset = config.dataset
        self.batch_size = config.batch_size
        self.num_iters = config.num_iters
        self.num_iters_decay = config.num_iters_decay
        self.g_lr = config.g_lr
        self.d_lr = config.d_lr
        self.n_critic = config.n_critic
        self.beta1 = config.beta1
        self.beta2 = config.beta2
        self.resume_iters = config.resume_iters

        # Test configurations.
        self.test_iters = config.test_iters

        # Miscellaneous.
        self.use_tensorboard = config.use_tensorboard
        self.use_wandb = config.use_wandb
        self.device = torch.device(config.device_name if torch.cuda.is_available() else 'cpu')
        self.run_name = config.run_name

        # Directories.
        self.log_dir = config.log_dir
        self.model_save_dir = config.model_save_dir

        # Step size.
        self.sync_step = config.sync_step
        self.log_step = config.log_step
        self.sample_step = config.sample_step
        self.model_save_step = config.model_save_step
        self.lr_update_step = config.lr_update_step
        self.vis_step = config.vis_step

        self.config = config
        # NOTE(review): requires at least three source domains; vis_batch is
        # not read anywhere else in this class — confirm it is still needed.
        self.vis_batch = next(iter(loader_dict[self.source_domains[2]]))[0]

        self.build_model()
        # Clients and loggers are not set up here; call init_clients(),
        # build_tensorboard() or init_wandb() explicitly when needed.
        # NOTE(review): fid_wd calls wandb.log when use_wandb is set, which
        # needs a prior wandb.init (e.g. via init_wandb) — confirm the caller
        # does this.

    def build_model(self):
        """Create the generator, the discriminator and their Adam optimizers.

        Also stores the combined parameter count of G and D in
        ``self.num_params`` (``run`` uses it for its parameter tally).
        """
        self.G = Generator(self.g_conv_dim, self.c_dim, self.g_repeat_num)
        self.D = Discriminator(self.image_size, self.d_conv_dim, self.c_dim, self.d_repeat_num)

        self.g_optimizer = torch.optim.Adam(self.G.parameters(), self.g_lr, [self.beta1, self.beta2])
        self.d_optimizer = torch.optim.Adam(self.D.parameters(), self.d_lr, [self.beta1, self.beta2])
        self.num_params = 0
        self.num_params += self.print_network(self.G, 'G')
        self.num_params += self.print_network(self.D, 'D')

        self.G.to(self.device)
        self.D.to(self.device)

    def print_network(self, model, name):
        """Print ``model``, its name and its parameter count; return the count."""
        num_params = sum(p.numel() for p in model.parameters())
        print(model)
        print(name)
        print("The number of parameters: {}".format(num_params))
        return num_params

    def restore_model(self, resume_iters):
        """Restore the trained generator and discriminator from step ``resume_iters``."""
        # NOTE(review): this filename pattern differs from the one run() reads
        # ('{dataset}_domain{target}_{step}-G.ckpt') — confirm which one the
        # training script writes.
        print('Loading the trained models from step {}...'.format(resume_iters))
        G_path = os.path.join(self.model_save_dir, '{}_{}-G.ckpt'.format(self.target_domain, resume_iters))
        D_path = os.path.join(self.model_save_dir, '{}_{}-D.ckpt'.format(self.target_domain, resume_iters))
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))
        self.D.load_state_dict(torch.load(D_path, map_location=lambda storage, loc: storage))

    def load_model(self, G_path):
        """Load generator weights from ``G_path``; the discriminator is left unchanged."""
        self.G.load_state_dict(torch.load(G_path, map_location=lambda storage, loc: storage))

    def build_tensorboard(self):
        """Build a tensorboard logger."""
        from logger import Logger
        self.logger = Logger(self.log_dir)

    def init_wandb(self, args):
        """Initialize the wandb run from ``args`` (project, entity, run_name)."""
        wandb.init(project=args.project, entity=args.entity, config=args, name=args.run_name)

    def update_lr(self, g_lr, d_lr):
        """Set the learning rate of every param group of the G and D optimizers."""
        for param_group in self.g_optimizer.param_groups:
            param_group['lr'] = g_lr
        for param_group in self.d_optimizer.param_groups:
            param_group['lr'] = d_lr

    def denorm(self, x):
        """Convert the range from [-1, 1] to [0, 1]."""
        out = (x + 1) / 2
        return out.clamp_(0, 1)

    def gradient_penalty(self, y, x):
        """Compute gradient penalty: mean((L2_norm(dy/dx) - 1)**2)."""
        # ones_like matches y's shape, dtype and device.
        weight = torch.ones_like(y)
        dydx = torch.autograd.grad(outputs=y,
                                   inputs=x,
                                   grad_outputs=weight,
                                   retain_graph=True,
                                   create_graph=True,
                                   only_inputs=True)[0]

        dydx = dydx.view(dydx.size(0), -1)
        dydx_l2norm = torch.sqrt(torch.sum(dydx ** 2, dim=1))
        return torch.mean((dydx_l2norm - 1) ** 2)

    def label2onehot(self, labels, dim):
        """Convert label indices to one-hot vectors."""
        batch_size = labels.size(0)
        out = torch.zeros(batch_size, dim)
        out[np.arange(batch_size), labels.long()] = 1
        return out

    def classification_loss(self, logit, target, dataset='CelebA'):
        """Softmax cross-entropy between ``logit`` and ``target`` (``dataset`` is ignored)."""
        return F.cross_entropy(logit, target)

    def average_model(self, coeffs=None):
        """Load the weighted average of the clients' G and D weights into self.G / self.D.

        Args:
            coeffs: one weight per entry of ``self.source_domains`` (same
                order). Defaults to uniform weights when falsy.

        Requires ``self.clients_dict`` (created by ``init_clients``).
        """
        n_clients = len(self.source_domains)
        if not coeffs:
            coeffs = [1 / n_clients] * n_clients

        d_keys = list(self.D.state_dict().keys())
        g_keys = list(self.G.state_dict().keys())
        averaged_D_weights = OrderedDict()
        averaged_G_weights = OrderedDict()
        for i, domain in enumerate(self.source_domains):
            client = self.clients_dict[domain]
            local_D_weight = client.D.state_dict()
            for key in d_keys:
                if i > 0:
                    averaged_D_weights[key] += coeffs[i] * local_D_weight[key]
                else:
                    # The product is a new tensor, so the += above never
                    # mutates a client's weights.
                    averaged_D_weights[key] = coeffs[i] * local_D_weight[key]
            local_G_weight = client.G.state_dict()
            for key in g_keys:
                if i > 0:
                    averaged_G_weights[key] += coeffs[i] * local_G_weight[key]
                else:
                    averaged_G_weights[key] = coeffs[i] * local_G_weight[key]
        self.D.load_state_dict(averaged_D_weights)
        self.G.load_state_dict(averaged_G_weights)

    def transmit_model(self):
        """Give every client an independent deep copy of the central G and D."""
        for domain in self.source_domains:
            self.clients_dict[domain].D = copy.deepcopy(self.D)
            self.clients_dict[domain].G = copy.deepcopy(self.G)

    def init_clients(self):
        """Create one client per source domain and send them the central model."""
        # NOTE(review): `Client` is neither imported nor defined in the visible
        # module source; this raises NameError unless it is provided elsewhere.
        clients_dict = dict()
        for domain in self.source_domains:
            clients_dict[domain] = Client(self.loader_dict[domain], domain, self.config)
        self.clients_dict = clients_dict

        # synchronize the model
        self.transmit_model()

    def fid_wd(self, loader, sg, get_wd, get_fid):
        """Average WD / FID over class labels 0..9 of the test data.

        Args:
            loader: iterable of ``(images, labels, domains)`` batches. All
                batches are evaluated together (concatenated along dim 0 when
                there is more than one).
            sg: generator used for translation, or ``None`` to compare real
                images only (``eval_fid_wd_init``).
            get_wd: compute Wasserstein distances.
            get_fid: compute FID scores.

        Returns:
            ``(avg_wd, avg_fid)`` averaged over the ten labels.

        Raises:
            ValueError: if ``loader`` yields no batches.
        """
        batches = list(loader)
        if not batches:
            raise ValueError('fid_wd: loader yielded no batches')
        if len(batches) == 1:
            imgs, labels, domains = batches[0]
        else:
            # Previously only the last batch was kept; use all of them.
            imgs, labels, domains = (torch.cat(parts) for parts in zip(*batches))

        domain_list = [0, 1, 2, 3, 4]
        label_list = list(range(10))

        total_wd = 0
        total_fid = 0
        for label in label_list:
            print(f'fid_wd/starting for label {label}')
            x_test, d_test = prepare_data_domains(imgs, labels, domains,
                                                  label, domain_list, train=False)
            if sg is not None:
                wd, _, fid, _ = eval_fid_wd(x_test, d_test, sg, domain_list, self.device,
                                            fid=get_fid, wd=get_wd)
            else:
                wd, _, fid, _ = eval_fid_wd_init(x_test, d_test, domain_list,
                                                 fid=get_fid, wd=get_wd)
            total_wd += wd
            total_fid += fid
            if self.use_wandb:
                # Log this label's own score, not the running total.
                wandb.log({f"M/wd/{label}": wd})
                wandb.log({f"M/fid/{label}": fid})
        return total_wd / len(label_list), total_fid / len(label_list)

    def run(self, resume=False):
        """Evaluate every saved generator checkpoint and pickle the metric history.

        Step 0 is a baseline on real images only (no generator). Then for each
        loop index ``i`` in ``[0, num_iters)`` the checkpoint
        ``{dataset}_domain{target_domain}_{i + 1}-G.ckpt`` in ``model_save_dir``
        is evaluated if it exists; missing checkpoints are skipped.

        Args:
            resume: if True and a previous metrics pickle exists, reuse the
                steps it contains instead of recomputing them. The default
                (False) matches the previously hard-disabled behaviour.

        Writes ``{ckpt_name}.pkl`` with ``(hist_wd, hist_fid, hist_step)`` and
        ``{ckpt_name}_with_params.pkl`` additionally holding ``hist_params``,
        a running count incremented by ``2 * num_params`` every ``sync_step``
        iterations.
        """
        start_iters = 0
        print('Start benchmarking...')

        ckpt_name = "metrics"
        if self.get_wd:
            ckpt_name += "_wd"
        if self.get_fid:
            ckpt_name += "_fid"
        log_path = os.path.join(self.model_save_dir, f'{ckpt_name}.pkl')
        if resume and os.path.exists(log_path):
            # Only ever unpickle this method's own output, never untrusted files.
            with open(log_path, 'rb') as f:
                hist_wd, hist_fid, hist_step = pkl.load(f)
        else:
            hist_wd, hist_fid, hist_step = [], [], []
        hist_params = [0]

        if 0 not in hist_step:
            avg_wd, avg_fid = self.fid_wd(self.test_loader, None, get_fid=self.get_fid, get_wd=self.get_wd)
            hist_wd = [avg_wd] + hist_wd
            hist_fid = [avg_fid] + hist_fid
            hist_step = [0] + hist_step
            print('init calc', hist_step[0], hist_wd[0], hist_fid[0], hist_params[0])
        else:
            print('init done', hist_step[0], hist_wd[0], hist_fid[0], hist_params[0])

        cumm_params = 0
        for i in range(start_iters, self.num_iters):

            # NOTE(review): presumably one upload and one download of G+D per
            # synchronisation round — confirm the factor of 2.
            if (i + 1) % self.sync_step == 0:
                cumm_params += 2 * self.num_params

            # Step already present in a resumed history: only record the tally.
            if i in hist_step and i != 0:
                index = hist_step.index(i)
                hist_params.append(cumm_params)
                print(index, hist_step[index], hist_wd[index], hist_params[-1])
                continue

            # The checkpoint for loop index i carries the number i + 1.
            G_path = os.path.join(self.model_save_dir, '{}_domain{}_{}-G.ckpt'.format(
                self.dataset, self.target_domain, i + 1))
            if not os.path.exists(G_path):
                continue

            self.load_model(G_path)

            print(G_path)
            avg_wd, avg_fid = self.fid_wd(self.test_loader, self.G, get_fid=self.get_fid, get_wd=self.get_wd)
            print("=" * 80)
            print('wd:', avg_wd, ', fid:', avg_fid)
            print("=" * 80)
            hist_wd.append(avg_wd)
            hist_fid.append(avg_fid)
            hist_step.append(i)
            hist_params.append(cumm_params)
            print(hist_step[-1], hist_wd[-1], hist_fid[-1], hist_params[-1])

        print('saving fid distance')
        with open(log_path, 'wb') as f:
            pkl.dump((hist_wd, hist_fid, hist_step), f)
        with open(os.path.join(self.model_save_dir, f'{ckpt_name}_with_params.pkl'), 'wb') as f:
            pkl.dump((hist_wd, hist_fid, hist_step, hist_params), f)
